//
// Created by Administrator on 2023/11/20.
//
#include <torch/torch.h>
#include "rasterize_points.h"

using namespace std;

void param_normal_data() {
    // 正常的spl能够输出彩色图像的输入参数
    float means3D_arr[] = {1.9591, -0.4911, 1.3208,
                           1.9800, 0.0902, 1.2955,
                           -1.5193, -0.0213, -1.2584,
                           3.6540, -0.1400, 2.1186,
                           0.6641, -0.8580, 1.2145,
                           -2.6418, -0.9816, -1.4900,
                           -0.2970, -0.9376, -0.1711,
                           -2.5155, -0.8410, -1.4656,
                           2.2086, 0.9839, 3.0780,
                           1.1933, 0.6980, -0.1107};
    torch::Tensor means3D = torch::from_blob(means3D_arr, {10, 3}, torch::dtype(torch::kFloat)).to(
            torch::kCUDA);  // 创建10行3列的数组
    float means2D_arr[] = {0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0};
    torch::Tensor means2D = torch::from_blob(means2D_arr, {10, 3}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    float shs_arr[] = {-0.4796, -0.0626, -0.2294,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,

                       0.9801, 0.8132, 0.2850,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,

                       -0.0904, 0.1599, -0.0209,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,

                       -0.0904, -0.0765, -0.0487,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,

                       0.5352, 0.8132, 1.0635,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,

                       -0.0487, 0.0904, 0.0765,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,

                       -0.2572, -0.2294, -0.3545,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,

                       -0.4379, -0.0070, -0.1877,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,

                       1.7029, 1.7307, 1.7168,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,

                       1.6056, 1.6334, 1.5917,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,
                       0, 0, 0,};
    torch::Tensor shs = torch::from_blob(shs_arr, {10, 16, 3}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    torch::Tensor colors_precomp = torch::tensor({0}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    float opacities_arr[] = {0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000};
    torch::Tensor opacities = torch::from_blob(opacities_arr, {10, 1}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    float scales_arr[] = {0.0060, 0.0060, 0.0060,
                          0.0083, 0.0083, 0.0083,
                          0.0078, 0.0078, 0.0078,
                          0.0076, 0.0076, 0.0076,
                          0.0208, 0.0208, 0.0208,
                          0.0154, 0.0154, 0.0154,
                          0.0123, 0.0123, 0.0123,
                          0.0120, 0.0120, 0.0120,
                          0.0317, 0.0317, 0.0317,
                          0.0111, 0.0111, 0.0111};
    torch::Tensor scales = torch::from_blob(scales_arr, {10, 3}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    float rotations_arr[] = {1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.};
    torch::Tensor rotations = torch::from_blob(rotations_arr, {10, 4}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    torch::Tensor cov3D_precomp = torch::tensor({0}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    torch::Tensor background = torch::tensor({0, 0, 0}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    torch::Tensor campos = torch::tensor({3.1687, 0.1043, 0.9233}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    bool debug = false;
    int image_height = 545;
    int image_width = 980;
    bool prefiltered = false;
    float scale_modifier = 1.0;
    int sh_degree = 0;
    float tanfovx = 0.8446965112441064;
    float tanfovy = 0.4679476755039769;
    float projmatrix_arr[] = {-8.2119e-02, 2.0427e-02, -9.9765e-01, -9.9755e-01,
                              1.8253e-01, 2.1114e+00, -1.2537e-03, -1.2535e-03,
                              1.1668e+00, -3.2887e-01, -7.0017e-02, -7.0010e-02,
                              -8.3616e-01, 1.8587e-02, 3.2160e+00, 3.2257e+00};
    torch::Tensor projmatrix = torch::from_blob(projmatrix_arr, {4, 4}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    float viewmatrix_arr[] = {-6.9366e-02, 9.5589e-03, -9.9755e-01, 0.0000e+00,
                              1.5418e-01, 9.8804e-01, -1.2535e-03, 0.0000e+00,
                              9.8560e-01, -1.5389e-01, -7.0010e-02, 0.0000e+00,
                              -7.0630e-01, 8.6978e-03, 3.2257e+00, 1.0000e+00};
    torch::Tensor viewmatrix = torch::from_blob(viewmatrix_arr, {4, 4}, torch::dtype(torch::kFloat)).to(torch::kCUDA);

    tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> result = RasterizeGaussiansCUDA(
            background,  // 背景颜色
            means3D,  // [Point_num, 3] 高斯中心点坐标 x y z
            colors_precomp,  // [0,] 预先计算的颜色
            opacities,  // [Point_num, 1] 每个高斯点的透明度
            scales,  // [Point_num, 3]  xyz轴的缩放尺度
            rotations,  // [Point_num, 4]  四元数？
            scale_modifier,  // int 1
            cov3D_precomp,  // [0,] 预先计算的3D协方差
            viewmatrix,  // [4, 4]
            projmatrix,  // [4, 4]
            tanfovx,  // float
            tanfovy,  // float
            image_height,  // int
            image_width,  // int
            shs,  // [Point_num, 16, 3]
            sh_degree,  // 0
            campos,  // [3,]
            prefiltered,  // false
            debug  // false
    );
}


void param_DFC2019_data() {
    // spl无法正常输出彩色图像的 DFC2019输入参数
    float means3D_arr[] = {1.0618e+01, -7.0330e+01, 5.5000e+00,
                           -5.8543e+01, -9.9306e+01, 2.4090e+00,
                           -1.1014e+02, -2.2178e+01, 3.8681e+01,
                           -2.8026e+01, 3.5305e+01, 3.7534e+00,
                           1.1613e+02, 3.9842e+01, 1.8743e+00,
                           -9.0227e+01, -6.0558e+01, 1.7636e+01,
                           8.7585e+01, -3.4586e+01, -1.8229e+00,
                           -1.1770e+01, 7.8075e+01, 4.1937e-03,
                           2.4226e+01, 1.2927e+02, 8.0854e+00,
                           1.1514e+02, 1.1482e+02, -3.5985e+00};
    torch::Tensor means3D = torch::from_blob(means3D_arr, {10, 3}, torch::dtype(torch::kFloat)).to(
            torch::kCUDA);  // 创建10行3列的数组
    float means2D_arr[] = {0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0,
                           0, 0, 0};
    torch::Tensor means2D = torch::from_blob(means2D_arr, {10, 3}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    float shs_arr[] = {1.4805, 1.5083, 1.4944,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.6464, 0.3962, 0.2433,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       -1.2998, -0.9245, -0.6186,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.1738, 0.1599, 0.2155,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       -0.0487, 0.2989, 0.1599,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.3684, 0.5352, 0.6742,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       -0.2294, 0.1738, 0.0765,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.4796, 0.4379, 0.4101,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       -0.5213, -0.6325, -0.7159,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       -0.2433, -0.0765, -0.0626,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000,
                       0.0000, 0.0000, 0.0000};
    torch::Tensor shs = torch::from_blob(shs_arr, {10, 16, 3}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    torch::Tensor colors_precomp = torch::tensor({0}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    float opacities_arr[] = {0.01000, 0.01000, 0.01000, 0.01000, 0.01000, 0.01000, 0.01000, 0.01000, 0.01000, 0.01000};
    torch::Tensor opacities = torch::from_blob(opacities_arr, {10, 1}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    float scales_arr[] = {1.4083, 1.4083, 1.4083,
                          2.2345, 2.2345, 2.2345,
                          2.9836, 2.9836, 2.9836,
                          1.2952, 1.2952, 1.2952,
                          1.6332, 1.6332, 1.6332,
                          0.7709, 0.7709, 0.7709,
                          2.1094, 2.1094, 2.1094,
                          1.7761, 1.7761, 1.7761,
                          1.8701, 1.8701, 1.8701,
                          2.0308, 2.0308, 2.0308};
    torch::Tensor scales = torch::from_blob(scales_arr, {10, 3}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    float rotations_arr[] = {1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.,
                             1., 0., 0., 0.};
    torch::Tensor rotations = torch::from_blob(rotations_arr, {10, 4}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    torch::Tensor cov3D_precomp = torch::tensor({0}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    torch::Tensor background = torch::tensor({0, 0, 0}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    torch::Tensor campos = torch::tensor({90630.8438, 236316.4062, 694838.3750}, torch::dtype(torch::kFloat)).to(
            torch::kCUDA);
    bool debug = false;
    int image_height = 815;
    int image_width = 746;
    bool prefiltered = false;
    float scale_modifier = 1.0;
    int sh_degree = 3;
    float tanfovx = 0.00017932269873479953;
    float tanfovy = 0.00021399145328289846;
    float projmatrix_arr[] = {3.2248e+01, 4.4823e+03, 2.8279e-01, 2.8277e-01,
                              5.3391e+03, 3.5660e+02, -2.7846e-01, -2.7843e-01,
                              -1.6096e+03, 1.2726e+03, -9.1798e-01, -9.1789e-01,
                              -1.4621e+08, -1.3748e+09, 6.7802e+05, 6.7795e+05};
    torch::Tensor projmatrix = torch::from_blob(projmatrix_arr, {4, 4}, torch::dtype(torch::kFloat)).to(torch::kCUDA);
    float viewmatrix_arr[] = {5.7828e-03, 9.5917e-01, 2.8277e-01, 0.0000e+00,
                              9.5742e-01, 7.6310e-02, -2.7843e-01, 0.0000e+00,
                              -2.8864e-01, 2.7234e-01, -9.1789e-01, 0.0000e+00,
                              -2.6220e+04, -2.9419e+05, 6.7795e+05, 1.0000e+00};
    torch::Tensor viewmatrix = torch::from_blob(viewmatrix_arr, {4, 4}, torch::dtype(torch::kFloat)).to(torch::kCUDA);

    tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> result = RasterizeGaussiansCUDA(
            background,  // 背景颜色
            means3D,  // [Point_num, 3] 高斯中心点坐标 x y z
            colors_precomp,  // [0,] 预先计算的颜色
            opacities,  // [Point_num, 1] 每个高斯点的透明度
            scales,  // [Point_num, 3]  xyz轴的缩放尺度
            rotations,  // [Point_num, 4]  四元数？
            scale_modifier,  // int 1
            cov3D_precomp,  // [0,] 预先计算的3D协方差
            viewmatrix,  // [4, 4]
            projmatrix,  // [4, 4]
            tanfovx,  // float
            tanfovy,  // float
            image_height,  // int
            image_width,  // int
            shs,  // [Point_num, 16, 3]
            sh_degree,  // 0
            campos,  // [3,]
            prefiltered,  // false
            debug  // false
    );
}


int main() {
    // Run only the failing DFC2019 case. param_normal_data() holds inputs that render correctly
    // and can be called here instead for comparison.
    param_DFC2019_data();
    // Falling off the end of main returns 0.
}